{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 解读 `tvm.tir.transform.LowerDeviceStorageAccessInfo`\n",
    "\n",
    "参考：`tvm/tests/python/tir-transform/test_tir_transform_lower_device_storage_access_info.py`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "from pathlib import Path\n",
    "ROOT = Path(\".\").resolve().parents[2]\n",
    "sys.path.extend([f\"{ROOT}/tests\", f\"{ROOT}/src\"])\n",
    "# # from tools.tag_span import _create_span, _set_span, _verify_structural_equal_with_span\n",
    "from tools.torch_utils import verify_model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tvm\n",
    "from tvm.script import tir as T"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "@tvm.register_func(\"tvm.info.mem.global.test_with_head_address\")\n",
    "def mem_info_with_head_address():\n",
    "    return tvm.ir.make_node(\n",
    "        \"MemoryInfo\",\n",
    "        unit_bits=8,\n",
    "        max_simd_bits=32,\n",
    "        max_num_bits=128,\n",
    "        head_address=tvm.tir.call_extern(\"handle\", \"dummy_head_address\"),\n",
    "    )\n",
    "\n",
    "@tvm.register_func(\"tvm.info.mem.global.test_without_head_address\")\n",
    "def mem_info_without_head_address():\n",
    "    return tvm.ir.make_node(\n",
    "        \"MemoryInfo\",\n",
    "        unit_bits=8,\n",
    "        max_simd_bits=32,\n",
    "        max_num_bits=128,\n",
    "        head_address=None,\n",
    "    )"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 将 CPU 可见的缓冲区分配替换为 LetStmt\n",
    "\n",
    "对于 CPU 可以访问的范围（例如 hexagon 上的 VTCM），头地址指定了如何访问它，并用于替换 AllocateNode。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class BaseCompare(tvm.testing.CompareBeforeAfter):\n",
    "    transform = tvm.tir.transform.LowerDeviceStorageAccessInfo()\n",
    "\n",
    "\n",
    "class TestLowerCPUAccessibleScope(BaseCompare):\n",
    "    \"\"\"Allocate of CPU-visible buffers are replaced by LetStmt\n",
    "\n",
    "    For scopes that are accessible by the CPU (e.g. VTCM on hexagon),\n",
    "    the head address specifies how it should be accessed, and is used\n",
    "    to replace the AllocateNode.\n",
    "    \"\"\"\n",
    "\n",
    "    def before():\n",
    "        ptr = T.allocate([16], \"float32\", scope=\"global.test_with_head_address\")\n",
    "        T.evaluate(ptr)\n",
    "\n",
    "    def expected():\n",
    "        ptr: T.handle(\"float32\", \"global.test_with_head_address\") = T.call_extern(  # noqa: F722\n",
    "            \"handle\", \"dummy_head_address\"\n",
    "        )\n",
    "        T.evaluate(ptr)\n",
    "\n",
    "\n",
    "class TestLowerCPUAccessibleScopeWithDeclBuffer(BaseCompare):\n",
    "    \"\"\"Like TestLowerCPUAccessibleScope, but with a DeclBuffer.\n",
    "\n",
    "    When the Allocate is updated, the DeclBuffer should not contain a\n",
    "    dangling reference.\n",
    "    \"\"\"\n",
    "\n",
    "    def before():\n",
    "        buf = T.decl_buffer(16, \"float32\", scope=\"global.test_with_head_address\")\n",
    "        T.evaluate(buf.data)\n",
    "\n",
    "    def expected():\n",
    "        ptr: T.handle(\"float32\", \"global.test_with_head_address\") = T.call_extern(  # noqa: F722\n",
    "            \"handle\", \"dummy_head_address\"\n",
    "        )\n",
    "        buf = T.decl_buffer(16, \"float32\", scope=\"global.test_with_head_address\", data=ptr)\n",
    "        T.evaluate(ptr)\n",
    "\n",
    "\n",
    "class TestLowerCPUInaccessibleScope(BaseCompare):\n",
    "    \"\"\"Allocate of CPU-visible buffers are replaced by LetStmt\n",
    "\n",
    "    For scopes that are inaccessible by the CPU (e.g. Texture memory\n",
    "    on GPU), the allocate is removed.  All CPU-side references to the\n",
    "    buffer should have been lowered by this point.\n",
    "    \"\"\"\n",
    "\n",
    "    def before():\n",
    "        ptr = T.allocate([16], \"float32\", scope=\"global.test_without_head_address\")\n",
    "        T.evaluate(0)\n",
    "\n",
    "    def expected():\n",
    "        T.evaluate(0)\n",
    "\n",
    "\n",
    "class TestLowerCPUInaccessibleScopeWithDeclBuffer(BaseCompare):\n",
    "    \"\"\"Like TestLowerCPUInaccessibleScope, but with a DeclBuffer\n",
    "\n",
    "    When the Allocate is removed, the DeclBuffer should not contain a\n",
    "    dangling reference.\n",
    "    \"\"\"\n",
    "\n",
    "    def before():\n",
    "        buf = T.decl_buffer(16, \"float32\", scope=\"global.test_without_head_address\")\n",
    "        T.evaluate(0)\n",
    "\n",
    "    def expected():\n",
    "        T.evaluate(0)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "py312x",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
